{{>licenseInfo}}
package {{invokerPackage}}.auth;

import io.micronaut.context.BeanContext;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.core.util.StringUtils;
import io.micronaut.core.util.Toggleable;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.MutableHttpRequest;
import io.micronaut.http.annotation.Filter;
import io.micronaut.http.filter.ClientFilterChain;
import io.micronaut.http.filter.HttpClientFilter;
import io.micronaut.inject.qualifiers.Qualifiers;
import io.micronaut.security.oauth2.client.clientcredentials.ClientCredentialsClient;
import io.micronaut.security.oauth2.client.clientcredentials.ClientCredentialsConfiguration;
import io.micronaut.security.oauth2.client.clientcredentials.propagation.ClientCredentialsHttpClientFilter;
import io.micronaut.security.oauth2.client.clientcredentials.propagation.ClientCredentialsTokenPropagator;
import io.micronaut.security.oauth2.configuration.OauthClientConfiguration;
import io.micronaut.security.oauth2.endpoint.token.response.TokenResponse;
import org.openapitools.auth.configuration.ConfigurableAuthorization;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Generated;


/**
 * HTTP client filter that applies the authorizations requested for an outgoing call.
 *
 * <p>The names of the authorizations to apply are read from the request attribute
 * {@link AuthorizationBinder#AUTHORIZATION_NAMES}. Each name is resolved first against the
 * registered {@link ConfigurableAuthorization} instances and, if none matches, against the
 * enabled OAuth 2.0 client configurations, in which case a token is requested through the
 * matching {@link ClientCredentialsClient} and written to the request. Names that cannot be
 * resolved are skipped.
 *
 * <p>The request is forwarded to the rest of the chain exactly once, after every resolved
 * authorization has completed.
 */
{{>generatedAnnotation}}
@Filter(Filter.MATCH_ALL_PATTERN)
public class AuthorizationFilter implements HttpClientFilter {
    private static final Logger LOG = LoggerFactory.getLogger(AuthorizationFilter.class);

    private final BeanContext beanContext;
    private final Map<String, OauthClientConfiguration> clientConfigurationByName;

    ClientCredentialsTokenPropagator defaultTokenPropagator;
    // Lazily populated lookup caches. They are written from doFilter, and one filter instance
    // may serve several requests at the same time, hence the concurrent map implementation.
    private final Map<String, ClientCredentialsTokenPropagator> tokenPropagatorByName;
    private final Map<String, ClientCredentialsClient> clientCredentialsClientByName;

    public final Map<String, ConfigurableAuthorization> authorizationsByName;

    /**
     * Creates the filter.
     *
     * @param beanContext bean context used to look up named token propagators and client
     *        credentials clients on demand
     * @param clientConfigurations OAuth 2.0 client configurations; only enabled ones are kept,
     *        indexed by name
     * @param defaultTokenPropagator propagator used when no propagator bean is registered under
     *        the authorization name
     * @param configurableAuthorizations non-OAuth authorizations, indexed by name
     */
    public AuthorizationFilter(
            BeanContext beanContext,
            Stream<OauthClientConfiguration> clientConfigurations,
            ClientCredentialsTokenPropagator defaultTokenPropagator,
            Stream<ConfigurableAuthorization> configurableAuthorizations
    ) {
        this.beanContext = beanContext;
        this.clientConfigurationByName = clientConfigurations
                .filter(Toggleable::isEnabled)
                .collect(Collectors.toMap(OauthClientConfiguration::getName, v -> v));
        this.defaultTokenPropagator = defaultTokenPropagator;
        this.tokenPropagatorByName = new ConcurrentHashMap<>();
        this.clientCredentialsClientByName = new ConcurrentHashMap<>();
        this.authorizationsByName = configurableAuthorizations
                .collect(Collectors.toMap(ConfigurableAuthorization::getName, v -> v));
    }

    /**
     * Applies every authorization named in the request attributes and then forwards the request.
     *
     * @param request the outgoing request, mutated in place by the authorizations
     * @param chain the remaining filter chain
     * @return the response publisher produced by the chain
     */
    @Override
    public Publisher<? extends HttpResponse<?>> doFilter(
            @NonNull MutableHttpRequest<?> request,
            @NonNull ClientFilterChain chain
    ) {
        List<?> names = request.getAttribute(AuthorizationBinder.AUTHORIZATION_NAMES, List.class).orElse(null);
        if (!CollectionUtils.isNotEmpty(names)) {
            return chain.proceed(request);
        }

        List<Publisher<HttpRequest<?>>> authorizers = new ArrayList<>(names.size());
        for (Object nameObject : names) {
            if (!(nameObject instanceof String)) {
                continue;
            }
            Publisher<HttpRequest<?>> authorizer = createAuthorizer((String) nameObject, request);
            if (authorizer != null) {
                authorizers.add(authorizer);
            }
        }

        // None of the requested names could be resolved: the request must still be sent.
        if (authorizers.isEmpty()) {
            return chain.proceed(request);
        }

        // Run the authorizers one after another and wait for all of them to complete before
        // proceeding, so that the request is forwarded exactly once and fully authorized.
        return Flux.concat(authorizers)
                .collectList()
                .flatMapMany(applied -> chain.proceed(request));
    }

    /**
     * Builds the publisher that applies the authorization registered under {@code name}.
     *
     * @param name the authorization name
     * @param request the request the authorization is written to
     * @return a publisher that applies the authorization when subscribed, or {@code null} when
     *         the name cannot be resolved
     */
    @Nullable
    private Publisher<HttpRequest<?>> createAuthorizer(String name, MutableHttpRequest<?> request) {
        // Explicitly configured authorizations take precedence over OAuth clients.
        ConfigurableAuthorization configured = authorizationsByName.get(name);
        if (configured != null) {
            return Mono.fromCallable(() -> {
                configured.applyAuthorization(request);
                return request;
            });
        }

        // Otherwise perform OAuth client-credentials authorization.
        OauthClientConfiguration clientConfiguration = clientConfigurationByName.get(name);
        if (clientConfiguration == null) {
            return null;
        }

        ClientCredentialsClient clientCredentialsClient = getClientCredentialsClient(name);
        if (clientCredentialsClient == null) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Could not retrieve client credentials client for OAuth 2.0 client {}", name);
            }
            return null;
        }

        ClientCredentialsTokenPropagator tokenHandler = getTokenPropagator(name);
        return Flux.from(clientCredentialsClient.requestToken(getScope(clientConfiguration)))
                .map(TokenResponse::getAccessToken)
                .map(accessToken -> {
                    if (StringUtils.isNotEmpty(accessToken)) {
                        tokenHandler.writeToken(request, accessToken);
                    }
                    return request;
                });
    }

    /**
     * Returns the token propagator for the given authorization name, looking up a bean
     * qualified by that name and falling back to the default propagator. Results are cached.
     *
     * @param name the authorization name
     * @return the propagator, or {@code null} if no named bean exists and no default is set
     */
    protected ClientCredentialsTokenPropagator getTokenPropagator(String name) {
        ClientCredentialsTokenPropagator tokenPropagator = tokenPropagatorByName.get(name);
        if (tokenPropagator == null) {
            tokenPropagator = beanContext.findBean(ClientCredentialsTokenPropagator.class, Qualifiers.byName(name))
                    .orElse(defaultTokenPropagator);
            if (tokenPropagator != null) {
                tokenPropagatorByName.put(name, tokenPropagator);
            }
        }
        return tokenPropagator;
    }

    /**
     * Returns the client credentials client bean qualified by the given name. Results are cached.
     *
     * @param name the authorization name
     * @return the client, or {@code null} if no such bean is found
     */
    protected ClientCredentialsClient getClientCredentialsClient(String name) {
        ClientCredentialsClient client = clientCredentialsClientByName.get(name);
        if (client == null) {
            client = beanContext.findBean(ClientCredentialsClient.class, Qualifiers.byName(name)).orElse(null);
            if (client != null) {
                clientCredentialsClientByName.put(name, client);
            }
        }
        return client;
    }

    /**
     * Extracts the scope from the client-credentials section of an OAuth client configuration.
     *
     * @param oauthClient the OAuth client configuration
     * @return the configured scope, or {@code null} if none is configured
     */
    @Nullable
    protected String getScope(@NonNull OauthClientConfiguration oauthClient) {
        return oauthClient.getClientCredentials().flatMap(ClientCredentialsConfiguration::getScope).orElse(null);
    }
}
